from typing import List, Union, Dict, Any, Optional
from openai import AsyncOpenAI
import os
from typing import List, Union, Dict
from agent_tools_normalized5_v1 import *
import asyncio
import atexit
# NOTE(review): hard-coded credential that overwrites any OPENAI_API_KEY already
# present in the environment. Presumably set here so the clients constructed
# below pick it up -- confirm, and consider moving it to deployment config.
os.environ["OPENAI_API_KEY"] = "apto-llm-4o-key"

# Tools bound to the model when a reference path is supplied
# (see get_llm_with_tools).
full_reference_tools = [
    TopIQ_FR_tool,
    AHIQ_tool,
    FSIM_tool,
    LPIPS_tool,
    DISTS_tool,
    WaDIQaM_FR_tool,
    PieAPP_tool,
    MS_SSIM_tool,
    GMSD_tool,
    SSIM_tool,
    CKDN_tool,
    VIF_tool,
    PSNR_tool,
    VSI_tool,
]

# Tools bound to the model when no reference path is supplied.
no_reference_tools = [
    QAlign_tool,
    CLIPIQA_tool,
    UNIQUE_tool,
    HyperIQA_tool,
    TReS_tool,
    WaDIQaM_NR_tool,
    DBCNN_tool,
    ARNIQA_tool,
    NIMA_tool,
    BRISQUE_tool,
    NIQE_tool,
    MANIQA_tool,
    LIQE_mix_tool,
]

# Module-level client that the atexit hook in this module closes; _CLOSED
# records that the hook has already run.
_GLOBAL_OPENAI_CLIENT = AsyncOpenAI()
_CLOSED = False

@atexit.register
def _shutdown_openai_client():
    """Close the module-level OpenAI client once, at interpreter exit.

    Registered with ``atexit`` but safe to call directly as well: the
    module-level ``_CLOSED`` flag turns every call after the first into a
    no-op. Failures are reported on stdout and never raised, so interpreter
    shutdown is not interrupted.
    """
    global _CLOSED
    if _CLOSED:
        return
    # Mark as done up front: whatever happens below, retrying would not help.
    _CLOSED = True

    pending = None
    try:
        client = _GLOBAL_OPENAI_CLIENT
        # The openai v1 async client exposes ``close()``; ``aclose()`` is kept
        # as a fallback for client objects that only provide that name.
        closer = getattr(client, "close", None) or client.aclose
        pending = closer()
        if asyncio.iscoroutine(pending):
            # Use a fresh event loop instead of asyncio.get_event_loop(),
            # which is deprecated when no loop is running.
            asyncio.run(pending)
    except Exception as e:
        if asyncio.iscoroutine(pending):
            # Avoid a "coroutine was never awaited" warning when asyncio.run
            # refused to start it; closing a finished coroutine is harmless.
            pending.close()
        print(f"[Shutdown Warning] Failed to close OpenAI client: {e}")


# ======================== Message Wrapper ========================
class Message:
    """One chat message: a role together with its content.

    ``content`` is stored exactly as given, either a plain string or a list
    of dicts.
    """

    def __init__(self, role: str, content: Union[str, List[Dict]]):
        self.role, self.content = role, content

    def to_dict(self):
        """Return the message as a dict with ``role`` and ``content`` keys."""
        return dict(role=self.role, content=self.content)

def tool_to_openai_function(tool):
    """Convert a tool object into an OpenAI ``tools`` entry of type "function".

    Args:
        tool: Object with ``name`` and ``description`` attributes and,
            optionally, an ``args_schema`` describing its arguments.

    Returns:
        A dict with ``type`` set to ``"function"`` and a ``function`` mapping
        holding ``name``, ``description`` (empty string when falsy) and
        ``parameters``.

    ``parameters`` is resolved from ``args_schema`` as follows:
      * missing or ``None``          -> ``{}``
      * already a dict               -> used as-is
      * has ``model_json_schema()``  -> that result
      * otherwise                    -> ``args_schema.schema()``
    """
    schema_source = getattr(tool, "args_schema", None)
    if schema_source is None:
        # The attribute can exist but be None; calling .schema() on it would
        # raise AttributeError.
        parameters = {}
    elif isinstance(schema_source, dict):
        parameters = schema_source
    elif hasattr(schema_source, "model_json_schema"):
        # Prefer the non-deprecated spelling where the schema class offers it.
        parameters = schema_source.model_json_schema()
    else:
        parameters = schema_source.schema()

    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description or "",
            "parameters": parameters,
        },
    }

class ChatModel:
    """Small async wrapper around the OpenAI chat-completions endpoint.

    Holds a model name, a sampling temperature, its own ``AsyncOpenAI``
    client and an optional list of bound tools that is sent with every
    request.
    """

    def __init__(self, model: str = "gpt-4o", temperature: float = 0.0, api_key: Optional[str] = None):
        self.model = model
        self.temperature = temperature
        self.client = AsyncOpenAI(api_key=api_key)
        # No tools until bind_tools() is called.
        self.bound_tools = None

    def bind_tools(self, tools: List[Any]) -> "ChatModel":
        """Attach ``tools`` to this instance and return the same instance."""
        self.bound_tools = tools
        return self

    def _request_payload(self, messages: List[Message], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the keyword arguments for one chat-completions call.

        ``overrides`` is applied last, so it replaces any default key
        (including ``tools``) that it also names.
        """
        request = {
            "model": self.model,
            "temperature": self.temperature,
            "messages": [message.to_dict() for message in messages],
        }
        if self.bound_tools:
            request["tools"] = [tool_to_openai_function(tool) for tool in self.bound_tools]
        request.update(overrides)
        return request

    async def ainvoke(self, messages: List[Message], **kwargs) -> Any:
        """Send ``messages`` and return only the first choice's message content.

        Extra keyword arguments are forwarded to the API call.
        """
        request = self._request_payload(messages, kwargs)
        completion = await self.client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content

    async def ainvoke_full(self, messages: List[Message], **kwargs) -> Any:
        """Send ``messages`` and return the complete API response object.

        Extra keyword arguments are forwarded to the API call.
        """
        request = self._request_payload(messages, kwargs)
        return await self.client.chat.completions.create(**request)

    async def aclose(self):
        """Do nothing; the client created in ``__init__`` is left open."""
        # OpenAI AsyncOpenAI client does not need to be closed explicitly
        return None

def get_llm_with_tools(ref_path=None, model_name="gpt-4o", temperature=0.0) -> ChatModel:
    """Build a ``ChatModel`` with the tool list matching the evaluation mode.

    A truthy ``ref_path`` selects ``full_reference_tools``; otherwise
    ``no_reference_tools`` is used. ``model_name`` and ``temperature`` are
    passed to the ``ChatModel`` constructor.
    """
    if ref_path:
        selected_tools = full_reference_tools
    else:
        selected_tools = no_reference_tools
    return ChatModel(model=model_name, temperature=temperature).bind_tools(selected_tools)

